# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Lowers break statements to conditionals."""

from nvidia.dali._autograph.core import converter
from nvidia.dali._autograph.pyct import anno
from nvidia.dali._autograph.pyct import qual_names
from nvidia.dali._autograph.pyct import templates
from nvidia.dali._autograph.pyct.static_analysis import activity
from nvidia.dali._autograph.pyct.static_analysis.annos import NodeAnno


class _Break(object):
    def __init__(self):
        self.used = False
        self.control_var_name = None

    def __repr__(self):
        return "used: %s, var: %s" % (self.used, self.control_var_name)


class BreakTransformer(converter.Base):
    """Canonicalizes break statements into additional conditionals.

    Each ``break`` found in a ``while``/``for`` body is replaced by
    ``<break_var> = True`` followed by ``continue``. The enclosing loop is then
    rewritten so that it stops once the control variable is set:

    * ``while`` loops get ``not <break_var> and`` prepended to their test;
    * ``for`` loops get ``not <break_var>`` attached to the new ``For`` node as
      an ``anno.Basic.EXTRA_LOOP_TEST`` annotation.

    The loop's ``else`` clause is emitted as plain statements after the loop.
    When a break was rewritten, those statements are wrapped in
    ``if not <break_var>:`` so they only run if the loop was not exited via
    ``break``. Loops whose body contains no ``break`` are re-emitted without
    any control variable.
    """

    def visit_Break(self, node):
        """Replaces a ``break`` with ``<control var> = True; continue``.

        Also records on the current ``_Break`` state that a break was used.
        The enclosing loop visitor reads that flag to decide whether it must
        introduce the control variable.

        Args:
            node: the ``Break`` node; unused, since it is replaced wholesale.

        Returns:
            The replacement statements produced by ``templates.replace``.
        """
        self.state[_Break].used = True
        # Set by _process_body for the loop body currently being visited.
        var_name = self.state[_Break].control_var_name
        # TODO(mdan): This will fail when expanded inside a top-level else block.
        template = """
      var_name = True
      continue
    """
        return templates.replace(template, var_name=var_name)

    def _guard_if_present(self, block, var_name):
        """Prevents the block from executing if var_name is set.

        Args:
            block: statements to guard (a loop's ``orelse``); may be empty.
            var_name: the break control variable to test.

        Returns:
            ``block`` itself when it is empty (nothing to guard). Otherwise,
            the statement(s) produced by wrapping it in ``if not var_name:``.
        """
        if not block:
            return block

        template = """
        if not var_name:
          block
      """
        node = templates.replace(template, var_name=var_name, block=block)
        return node

    def _process_body(self, nodes, break_var):
        """Visits a loop body with ``break_var`` as the active control variable.

        Args:
            nodes: the loop body statements.
            break_var: the symbol (from ``ctx.namer.new_symbol``) that rewritten
                ``break`` statements in this body should assign to.

        Returns:
            A ``(nodes, break_used)`` tuple: the transformed body, and whether
            ``visit_Break`` set ``used`` while this body's state was active.
        """
        # enter()/exit() scope the _Break state to this body. The state
        # machinery comes from converter.Base (not visible here); presumably
        # nested loops get their own scope, so `used` only reflects breaks
        # that target this loop.
        self.state[_Break].enter()
        self.state[_Break].control_var_name = break_var
        nodes = self.visit_block(nodes)
        break_used = self.state[_Break].used
        self.state[_Break].exit()
        return nodes, break_used

    def visit_While(self, node):
        """Rewrites a ``while`` loop, introducing a control variable if needed.

        When the body contains a ``break``, the result has the shape::

            <break_var> = False
            while not <break_var> and <test>:
                <body>
            if not <break_var>:      # only if there was an else clause
                <orelse>

        Without a ``break``, the loop is re-emitted as ``while <test>: <body>``
        followed by the ``orelse`` statements. In both cases the original
        node's ``anno.Basic.DIRECTIVES`` are copied onto the new ``While`` node.

        Returns:
            The list of replacement statements.
        """
        original_node = node
        scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
        # Allocated before visiting the body. Passing the body's referenced
        # names presumably keeps the new symbol from clashing with them.
        break_var = self.ctx.namer.new_symbol("break_", scope.referenced)

        node.test = self.visit(node.test)
        node.body, break_used = self._process_body(node.body, break_var)
        # A break in the else clause applies to the containing scope.
        node.orelse = self.visit_block(node.orelse)

        if not break_used:
            template = """
        while test:
          body
        orelse
      """
            node = templates.replace(template, test=node.test, body=node.body, orelse=node.orelse)

            # The loop is the first statement produced by the template above.
            new_while_node = node[0]
            anno.copyanno(original_node, new_while_node, anno.Basic.DIRECTIVES)

            return node

        # Python's else clause only triggers if the loop exited cleanly (i.e.
        # break did not trigger).
        guarded_orelse = self._guard_if_present(node.orelse, break_var)

        template = """
      var_name = False
      while not var_name and test:
        body
      orelse
    """
        node = templates.replace(
            template, var_name=break_var, test=node.test, body=node.body, orelse=guarded_orelse
        )

        # node[0] is the `var_name = False` initializer; node[1] is the loop.
        new_while_node = node[1]
        anno.copyanno(original_node, new_while_node, anno.Basic.DIRECTIVES)

        return node

    def visit_For(self, node):
        """Rewrites a ``for`` loop, introducing a control variable if needed.

        When the body contains a ``break``, the result has the shape::

            <break_var> = False
            for <target> in <iter>:
                (<break_var>,)       # no-op so analysis sees the variable used
                <body>
            if not <break_var>:      # only if there was an else clause
                <orelse>

        The new ``For`` node also gets ``not <break_var>`` set as its
        ``anno.Basic.EXTRA_LOOP_TEST`` annotation.

        Without a ``break``, the loop is re-emitted unchanged in shape, and the
        original ``EXTRA_LOOP_TEST`` annotation is copied over. In both cases
        ``anno.Basic.DIRECTIVES`` are copied onto the new ``For`` node.

        Returns:
            The list of replacement statements.
        """
        original_node = node
        scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
        # Allocated before visiting the body. Passing the body's referenced
        # names presumably keeps the new symbol from clashing with them.
        break_var = self.ctx.namer.new_symbol("break_", scope.referenced)

        node.target = self.visit(node.target)
        node.iter = self.visit(node.iter)
        node.body, break_used = self._process_body(node.body, break_var)
        # A break in the else clause applies to the containing scope.
        node.orelse = self.visit_block(node.orelse)

        if not break_used:
            template = """
        for target in iter_:
          body
        orelse
      """
            node = templates.replace(
                template, iter_=node.iter, target=node.target, body=node.body, orelse=node.orelse
            )

            # The loop is the first statement produced by the template above.
            new_for_node = node[0]
            anno.copyanno(original_node, new_for_node, anno.Basic.EXTRA_LOOP_TEST)
            anno.copyanno(original_node, new_for_node, anno.Basic.DIRECTIVES)

            return node

        # Python's else clause only triggers if the loop exited cleanly (i.e.
        # break did not trigger).
        guarded_orelse = self._guard_if_present(node.orelse, break_var)
        extra_test = templates.replace_as_expression("not var_name", var_name=break_var)

        # The extra test is hidden in the AST, which will confuse the static
        # analysis. To mitigate that, we insert a no-op statement that ensures
        # the control variable is marked as used.
        # TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name)
        template = """
      var_name = False
      for target in iter_:
        (var_name,)
        body
      orelse
    """
        node = templates.replace(
            template,
            var_name=break_var,
            iter_=node.iter,
            target=node.target,
            body=node.body,
            orelse=guarded_orelse,
        )

        # node[0] is the `var_name = False` initializer; node[1] is the loop.
        new_for_node = node[1]
        # NOTE(review): unlike the no-break path, any EXTRA_LOOP_TEST already on
        # original_node is not carried over here; the new test replaces it.
        anno.setanno(new_for_node, anno.Basic.EXTRA_LOOP_TEST, extra_test)
        anno.copyanno(original_node, new_for_node, anno.Basic.DIRECTIVES)

        return node


def transform(node, ctx):
    """Lowers ``break`` statements in ``node`` to conditionals.

    Qualified names and activity analysis are resolved first.
    ``BreakTransformer`` reads the ``NodeAnno.BODY_SCOPE`` loop annotations,
    which that analysis is expected to provide.

    Args:
        node: the AST to transform.
        ctx: the converter context handed to ``activity.resolve`` and
            ``BreakTransformer``.

    Returns:
        The transformed AST.
    """
    analyzed = activity.resolve(qual_names.resolve(node), ctx, None)
    return BreakTransformer(ctx).visit(analyzed)
